当前位置:  开发笔记 > 编程语言 > 正文

`sample_weight`对`DecisionTreeClassifier`在sklearn中的工作方式有何作用?

如何解决《`sample_weight`对`DecisionTreeClassifier`在sklearn中的工作方式有何作用?》经验,为你挑选了2个好方法。

我从这个文档中读到:

"可以通过从每个类中抽取相同数量的样本来完成类平衡,或者最好通过将每个类的样本权重(sample_weight)的总和归一化为相同的值."

但是,我仍然不清楚它是如何工作的.如果我设置sample_weight一个只有两个可能值的数组,1's和2's,这是否意味着带有2's的样本1在进行装袋时的采样频率是采样的两倍?我想不出一个实际的例子.



1> Matt Hancock..:

所以我花了一点时间看一下sklearn来源,因为我实际上已经意味着试着自己解决这个问题一段时间了.我为这个长度道歉,但我不知道如何更简单地解释它.


一些快速的预赛:

假设我们对K类有分类问题.在由决策树的节点表示的特征空间的区域中,回想一下,通过使用该区域中的类的概率来量化不均匀性来测量该区域的"杂质".通常,我们估计:

Pr(Class=k) = #(examples of class k in region) / #(total examples in region)

杂质测量作为输入,类概率数组:

[Pr(Class=1), Pr(Class=2), ..., Pr(Class=K)]

并吐出一个数字,它告诉你"不纯"或特征空间区域是如何不同类的.例如,两类问题的基尼系数衡量标准2*p*(1-p),其中p = Pr(Class=1)1-p=Pr(Class=2).


现在,基本上你的问题的简短答案是:

sample_weight增加了概率数组中的概率估计 ......这增加了杂质测量...这增加了节点的分裂方式......这增加了树的构建方式......这增加了如何将特征空间切割为分类.

我相信通过例子可以说明这一点.


首先考虑输入为1维的以下2类问题:

from sklearn.tree import DecisionTreeClassifier as DTC

X = [[0],[1],[2]] # 3 simple training examples
Y = [ 1,  2,  1 ] # class labels

dtc = DTC(max_depth=1)

所以,我们只会看到一个只有一个根节点和两个孩子的树.请注意,默认杂质测量gini度量.


案例1:没有 sample_weight
dtc.fit(X,Y)
print dtc.tree_.threshold
# [0.5, -2, -2]
print dtc.tree_.impurity
# [0.44444444, 0, 0.5]

threshold数组中的第一个值告诉我们第一个训练示例被发送到左子节点,第二个和第三个训练示例被发送到右子节点.最后两个值threshold是占位符,将被忽略.该impurity数组分别告诉我们父节点,左节点和右节点的计算杂质值.

在父节点中p = Pr(Class=1) = 2. / 3.,这样gini = 2*(2.0/3.0)*(1.0/3.0) = 0.444.....您也可以确认子节点杂质.


案例2:与 sample_weight

现在,让我们试试:

dtc.fit(X,Y,sample_weight=[1,2,3])
print dtc.tree_.threshold
# [1.5, -2, -2]
print dtc.tree_.impurity
# [0.44444444, 0.44444444, 0.]

您可以看到功能阈值不同.sample_weight也会影响每个节点的杂质测量.具体地,在概率估计中,由于我们提供的样本权重,第一训练示例被计数相同,第二训练示例被计数加倍,第三训练示例被计为三倍.

父节点区域中的杂质是相同的.这只是一个巧合.我们可以直接计算它:

p = Pr(Class=1) = (1+3) / (1+2+3) = 2.0/3.0

基尼系数衡量4/9如下.

现在,您可以从所选阈值中看到第一个和第二个训练样例被发送到左子节点,而第三个训练示例被发送到右侧.我们看到杂质计算4/9也在左子节点中,因为:

p = Pr(Class=1) = 1 / (1+2) = 1/3.

正确孩子的零杂质仅归因于该地区的一个训练样例.

您可以使用非整数样本重量来扩展它.我建议尝试类似的东西sample_weight = [1,2,2.5],并确认计算出的杂质.

希望这可以帮助!



2> Chris Farr..:

如果有人像我一样正在寻找计算sample_weight的方法,您可能会发现这很方便。

sklearn.utils.class_weight.compute_sample_weight

推荐阅读
个性2402852463
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有